#!/usr/bin/env python3

from typing import Optional, Union

import torch
from torch import Tensor

from .. import settings
from ..constraints import Interval, Positive
from ..priors import Prior
from .kernel import Kernel


class CylindricalKernel(Kernel):
    r"""
    Computes a covariance matrix based on the Cylindrical Kernel between
    inputs :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`.
    It was proposed in `BOCK: Bayesian Optimization with Cylindrical Kernels`.
    See http://proceedings.mlr.press/v80/oh18a.html for more details

    The kernel is the product of an angular kernel (a polynomial in the cosine
    of the angle between inputs) and a radial kernel (the base kernel applied
    to the Kumaraswamy-warped radii).

    .. note::
        The data must lie completely within the unit ball.

    Args:
        num_angular_weights (int):
            The number of components in the angular kernel
        radial_base_kernel (gpytorch.kernels.Kernel):
            The base kernel for computing the radial kernel
        eps (float):
            Small floating point number used to improve numerical stability
            in kernel computations. Default: `1e-6`
        angular_weights_prior (gpytorch.priors.Prior, optional):
            Prior over the angular weights
        angular_weights_constraint (gpytorch.constraints.Interval, optional):
            Constraint on the angular weights. Default: `Positive`
        alpha_prior (gpytorch.priors.Prior, optional):
            Prior over the Kumaraswamy warping parameter :math:`\alpha`
        alpha_constraint (gpytorch.constraints.Interval, optional):
            Constraint on :math:`\alpha`. Default: `Positive`
        beta_prior (gpytorch.priors.Prior, optional):
            Prior over the Kumaraswamy warping parameter :math:`\beta`
        beta_constraint (gpytorch.constraints.Interval, optional):
            Constraint on :math:`\beta`. Default: `Positive`
    """

    def __init__(
        self,
        num_angular_weights: int,
        radial_base_kernel: Kernel,
        eps: float = 1e-6,
        angular_weights_prior: Optional[Prior] = None,
        angular_weights_constraint: Optional[Interval] = None,
        alpha_prior: Optional[Prior] = None,
        alpha_constraint: Optional[Interval] = None,
        beta_prior: Optional[Prior] = None,
        beta_constraint: Optional[Interval] = None,
        **kwargs,
    ):
        # All three parameter groups default to a positivity constraint.
        if angular_weights_constraint is None:
            angular_weights_constraint = Positive()

        if alpha_constraint is None:
            alpha_constraint = Positive()

        if beta_constraint is None:
            beta_constraint = Positive()

        super().__init__(**kwargs)
        self.num_angular_weights = num_angular_weights
        self.radial_base_kernel = radial_base_kernel
        self.eps = eps

        # Raw (unconstrained) parameters; the public properties below apply the
        # constraint transforms.
        self.register_parameter(
            name="raw_angular_weights",
            parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, num_angular_weights)),
        )
        self.register_constraint("raw_angular_weights", angular_weights_constraint)
        self.register_parameter(name="raw_alpha", parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)))
        self.register_constraint("raw_alpha", alpha_constraint)
        self.register_parameter(name="raw_beta", parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)))
        self.register_constraint("raw_beta", beta_constraint)

        if angular_weights_prior is not None:
            if not isinstance(angular_weights_prior, Prior):
                raise TypeError("Expected gpytorch.priors.Prior but got " + type(angular_weights_prior).__name__)
            self.register_prior(
                "angular_weights_prior",
                angular_weights_prior,
                lambda m: m.angular_weights,
                lambda m, v: m._set_angular_weights(v),
            )
        if alpha_prior is not None:
            if not isinstance(alpha_prior, Prior):
                raise TypeError("Expected gpytorch.priors.Prior but got " + type(alpha_prior).__name__)
            self.register_prior("alpha_prior", alpha_prior, lambda m: m.alpha, lambda m, v: m._set_alpha(v))
        if beta_prior is not None:
            if not isinstance(beta_prior, Prior):
                raise TypeError("Expected gpytorch.priors.Prior but got " + type(beta_prior).__name__)
            self.register_prior("beta_prior", beta_prior, lambda m: m.beta, lambda m, v: m._set_beta(v))

    @property
    def angular_weights(self) -> Tensor:
        return self.raw_angular_weights_constraint.transform(self.raw_angular_weights)

    @angular_weights.setter
    def angular_weights(self, value: Union[Tensor, float]) -> None:
        self._set_angular_weights(value)

    def _set_angular_weights(self, value: Union[Tensor, float]) -> None:
        # Used by the angular_weights_prior's setting closure. This helper was
        # previously missing, so sampling from that prior raised AttributeError.
        # Mirrors _set_alpha/_set_beta, including as_tensor(...).to(raw) to
        # preserve the raw parameter's dtype and device.
        if not isinstance(value, Tensor):
            value = torch.as_tensor(value).to(self.raw_angular_weights)
        self.initialize(raw_angular_weights=self.raw_angular_weights_constraint.inverse_transform(value))

    @property
    def alpha(self) -> Tensor:
        return self.raw_alpha_constraint.transform(self.raw_alpha)

    @alpha.setter
    def alpha(self, value: Union[Tensor, float]) -> None:
        self._set_alpha(value)

    def _set_alpha(self, value: Union[Tensor, float]) -> None:
        # Used by the alpha_prior
        if not isinstance(value, Tensor):
            value = torch.as_tensor(value).to(self.raw_alpha)
        self.initialize(raw_alpha=self.raw_alpha_constraint.inverse_transform(value))

    @property
    def beta(self) -> Tensor:
        return self.raw_beta_constraint.transform(self.raw_beta)

    @beta.setter
    def beta(self, value: Union[Tensor, float]) -> None:
        self._set_beta(value)

    def _set_beta(self, value: Union[Tensor, float]) -> None:
        # Used by the beta_prior
        if not isinstance(value, Tensor):
            value = torch.as_tensor(value).to(self.raw_beta)
        self.initialize(raw_beta=self.raw_beta_constraint.inverse_transform(value))

    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Tensor:
        """Evaluate the cylindrical kernel.

        Args:
            x1: First set of inputs, `... x n x d`, inside the unit ball.
            x2: Second set of inputs, `... x m x d`, inside the unit ball.
            diag: If True, return only the kernel diagonal.

        Raises:
            RuntimeError: If any input has radius greater than 1.
        """
        x1_, x2_ = x1.clone(), x2.clone()
        # Jitter coordinates that are exactly 0 so the radii below are nonzero
        # (avoids division by zero when normalizing an all-zero point).
        x1_[x1_ == 0], x2_[x2_ == 0] = x1_[x1_ == 0] + self.eps, x2_[x2_ == 0] + self.eps
        r1, r2 = x1_.norm(dim=-1, keepdim=True), x2_.norm(dim=-1, keepdim=True)

        if torch.any(r1 > 1.0) or torch.any(r2 > 1.0):
            raise RuntimeError("Cylindrical kernel not defined for data points with radius > 1. Scale your data!")

        # Unit directions. NOTE(review): numerators use the unjittered inputs,
        # so an exactly-zero point maps to the zero direction — confirm intended.
        a1, a2 = x1.div(r1), x2.div(r2)
        if not diag:
            # gram_mat[..., i, j] = cos(angle between x1_i and x2_j);
            # the angular kernel is sum_p w_p * gram_mat^p.
            gram_mat = a1.matmul(a2.transpose(-2, -1))
            for p in range(self.num_angular_weights):
                if p == 0:
                    angular_kernel = self.angular_weights[..., 0, None, None]
                else:
                    angular_kernel = angular_kernel + self.angular_weights[..., p, None, None].mul(gram_mat.pow(p))
        else:
            # Diagonal case: only the pairwise cosines of matched points.
            gram_mat = a1.mul(a2).sum(-1)
            for p in range(self.num_angular_weights):
                if p == 0:
                    angular_kernel = self.angular_weights[..., 0, None]
                else:
                    angular_kernel = angular_kernel + self.angular_weights[..., p, None].mul(gram_mat.pow(p))

        # Radial part: base kernel on the Kumaraswamy-warped radii. Lazy
        # evaluation is disabled so the product below gets a concrete tensor.
        with settings.lazily_evaluate_kernels(False):
            radial_kernel = self.radial_base_kernel(self.kuma(r1), self.kuma(r2), diag=diag, **params)
        return radial_kernel.mul(angular_kernel)

    def kuma(self, x: Tensor) -> Tensor:
        """Kumaraswamy CDF warping of the radius: `1 - (1 - x^alpha)^beta`.

        The `+ self.eps` keeps the inner base strictly positive for gradient
        stability when `x` is at (or numerically above) 1.
        """
        alpha = self.alpha.view(*self.batch_shape, 1, 1)
        beta = self.beta.view(*self.batch_shape, 1, 1)

        res = 1 - (1 - x.pow(alpha) + self.eps).pow(beta)
        return res

    def num_outputs_per_input(self, x1: Tensor, x2: Tensor) -> int:
        # Delegates to the radial base kernel, which determines the output
        # structure of the product kernel.
        return self.radial_base_kernel.num_outputs_per_input(x1, x2)